{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Basic usage"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*`skorch`* is designed to maximize interoperability between `sklearn` and `pytorch`. The aim is to keep 99% of the flexibility of `pytorch` while being able to leverage most features of `sklearn`. Below, we show the basic usage of `skorch` and how it can be combined with `sklearn`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook shows you how to use the basic functionality of `skorch`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Table of contents"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* [Training a classifier](#Training-a-classifier-and-making-predictions)\n",
    "  * [Dataset](#A-toy-binary-classification-task)\n",
    "  * [pytorch module](#Definition-of-the-pytorch-classification-module)\n",
    "  * [Model training](#Defining-and-training-the-neural-net-classifier)\n",
    "  * [Inference](#Making-predictions,-classification)\n",
    "* [Training a regressor](#Training-a-regressor)\n",
    "  * [Dataset](#A-toy-regression-task)\n",
    "  * [pytorch module](#Definition-of-the-pytorch-regression-module)\n",
    "  * [Model training](#Defining-and-training-the-neural-net-regressor)\n",
    "  * [Inference](#Making-predictions,-regression)\n",
    "* [Saving and loading a model](#Saving-and-loading-a-model)\n",
    "  * [Whole model](#Saving-the-whole-model)\n",
    "  * [Only parameters](#Saving-only-the-model-parameters)\n",
    "* [Usage with an sklearn Pipeline](#Usage-with-an-sklearn-Pipeline)\n",
    "* [Callbacks](#Callbacks)\n",
    "* [Grid search](#Usage-with-sklearn-GridSearchCV)\n",
    "  * [Special prefixes](#Special-prefixes)\n",
    "  * [Performing a grid search](#Performing-a-grid-search)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "torch.manual_seed(0);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training a classifier and making predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A toy binary classification task"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We load a toy classification task from `sklearn`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import make_classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "X, y = make_classification(1000, 20, n_informative=10, random_state=0)\n",
    "X = X.astype(np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((1000, 20), (1000,), 0.5)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X.shape, y.shape, y.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Definition of the `pytorch` classification `module`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We define a vanilla neural network with two hidden layers. The output layer should have 2 output units since there are two classes. In addition, it should apply a softmax nonlinearity, because `predict_proba` uses the output of the module's `forward` call, which therefore has to consist of class probabilities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class ClassifierModule(nn.Module):\n",
    "    def __init__(\n",
    "            self,\n",
    "            num_units=10,\n",
    "            nonlin=F.relu,\n",
    "            dropout=0.5,\n",
    "    ):\n",
    "        super(ClassifierModule, self).__init__()\n",
    "        self.num_units = num_units\n",
    "        self.nonlin = nonlin\n",
    "\n",
    "        self.dense0 = nn.Linear(20, num_units)  # 20 input features\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.dense1 = nn.Linear(num_units, 10)\n",
    "        self.output = nn.Linear(10, 2)  # one output unit per class\n",
    "\n",
    "    def forward(self, X, **kwargs):\n",
    "        X = self.nonlin(self.dense0(X))\n",
    "        X = self.dropout(X)\n",
    "        X = F.relu(self.dense1(X))\n",
    "        # softmax turns the two outputs into class probabilities,\n",
    "        # which is what predict_proba returns\n",
    "        X = F.softmax(self.output(X), dim=-1)\n",
    "        return X"
   ]
  },
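  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what the final `F.softmax(..., dim=-1)` does, the following sketch re-implements a softmax in plain NumPy. It is not taken from `skorch` or `pytorch`, and the raw outputs are made up. Each row of raw outputs becomes a set of non-negative values that sum to 1, i.e. a probability distribution over the two classes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def softmax(logits):\n",
    "    # subtracting the row-wise maximum avoids overflow and does not change the result\n",
    "    shifted = logits - logits.max(axis=-1, keepdims=True)\n",
    "    exp = np.exp(shifted)\n",
    "    return exp / exp.sum(axis=-1, keepdims=True)\n",
    "\n",
    "\n",
    "# made-up raw outputs of the last linear layer: three samples, two classes\n",
    "logits = np.array([[2.0, 1.0], [0.5, 0.5], [-1.0, 3.0]])\n",
    "proba = softmax(logits)\n",
    "print(proba)\n",
    "print(proba.sum(axis=-1))  # every row sums to 1"
   ]
  },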
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Defining and training the neural net classifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use `NeuralNetClassifier` because we're dealing with a classification task. The first argument should be the `pytorch module`. As additional arguments, we pass the number of epochs and the learning rate (`lr`), but those are optional.\n",
    "\n",
    "*Note*: To use the CUDA backend, pass `use_cuda=True` as an additional argument."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from skorch.net import NeuralNetClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "net = NeuralNetClassifier(\n",
    "    ClassifierModule,\n",
    "    max_epochs=20,\n",
    "    lr=0.1,\n",
    "    # use_cuda=True,  # uncomment this to train with CUDA\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As in `sklearn`, we call `fit`, passing the input data `X` and the targets `y`. By default, `NeuralNetClassifier` makes a `StratifiedKFold` split on the data (80/20) to track the validation loss. After each epoch, the validation loss is printed together with the training loss and the accuracy on the validation set."
   ]
  },
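  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The 80/20 split can be imitated with plain `sklearn` to see what the validation set looks like. This is a sketch under the assumption that the default split behaves like taking the first fold of a 5-fold `StratifiedKFold` (5 folds leave 1/5 = 20% of the samples for validation). It is not `skorch`'s internal code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import make_classification\n",
    "from sklearn.model_selection import StratifiedKFold\n",
    "\n",
    "X_demo, y_demo = make_classification(1000, 20, n_informative=10, random_state=0)\n",
    "\n",
    "# 5 folds -> each validation fold holds 1/5 = 20% of the samples\n",
    "skf = StratifiedKFold(n_splits=5)\n",
    "train_idx, valid_idx = next(skf.split(X_demo, y_demo))\n",
    "\n",
    "print(len(train_idx), len(valid_idx))\n",
    "# stratification keeps the class balance (about 50/50 here) in both parts\n",
    "print(y_demo[train_idx].mean(), y_demo[valid_idx].mean())"
   ]
  },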
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  epoch    train_loss    valid_acc    valid_loss     dur\n",
      "-------  ------------  -----------  ------------  ------\n",
      "      1        \u001b[36m0.6868\u001b[0m       \u001b[32m0.6000\u001b[0m        \u001b[35m0.6740\u001b[0m  0.0793\n",
      "      2        \u001b[36m0.6706\u001b[0m       \u001b[32m0.6400\u001b[0m        \u001b[35m0.6617\u001b[0m  0.0686\n",
      "      3        \u001b[36m0.6637\u001b[0m       \u001b[32m0.6650\u001b[0m        \u001b[35m0.6504\u001b[0m  0.0541\n",
      "      4        \u001b[36m0.6548\u001b[0m       \u001b[32m0.7000\u001b[0m        \u001b[35m0.6418\u001b[0m  0.0535\n",
      "      5        \u001b[36m0.6340\u001b[0m       \u001b[32m0.7100\u001b[0m        \u001b[35m0.6272\u001b[0m  0.0539\n",
      "      6        \u001b[36m0.6219\u001b[0m       \u001b[32m0.7150\u001b[0m        \u001b[35m0.6124\u001b[0m  0.0574\n",
      "      7        \u001b[36m0.6058\u001b[0m       0.7100        \u001b[35m0.5980\u001b[0m  0.0530\n",
      "      8        \u001b[36m0.5964\u001b[0m       \u001b[32m0.7200\u001b[0m        \u001b[35m0.5875\u001b[0m  0.0646\n",
      "      9        \u001b[36m0.5901\u001b[0m       0.7100        \u001b[35m0.5760\u001b[0m  0.0572\n",
      "     10        \u001b[36m0.5716\u001b[0m       \u001b[32m0.7250\u001b[0m        \u001b[35m0.5651\u001b[0m  0.0460\n",
      "     11        \u001b[36m0.5633\u001b[0m       0.7250        \u001b[35m0.5580\u001b[0m  0.0471\n",
      "     12        0.5652       \u001b[32m0.7300\u001b[0m        \u001b[35m0.5529\u001b[0m  0.0453\n",
      "     13        \u001b[36m0.5462\u001b[0m       \u001b[32m0.7350\u001b[0m        \u001b[35m0.5426\u001b[0m  0.0500\n",
      "     14        \u001b[36m0.5407\u001b[0m       0.7300        \u001b[35m0.5407\u001b[0m  0.0448\n",
      "     15        \u001b[36m0.5360\u001b[0m       0.7300        \u001b[35m0.5373\u001b[0m  0.0464\n",
      "     16        0.5517       \u001b[32m0.7400\u001b[0m        \u001b[35m0.5328\u001b[0m  0.0448\n",
      "     17        \u001b[36m0.5351\u001b[0m       \u001b[32m0.7450\u001b[0m        \u001b[35m0.5277\u001b[0m  0.0460\n",
      "     18        \u001b[36m0.5280\u001b[0m       0.7400        \u001b[35m0.5260\u001b[0m  0.0530\n",
      "     19        \u001b[36m0.5148\u001b[0m       0.7450        0.5264  0.0583\n",
      "     20        0.5309       0.7400        \u001b[35m0.5210\u001b[0m  0.0740\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<skorch.net.NeuralNetClassifier at 0x7fb203e05668>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net.fit(X, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Also, as in `sklearn`, you may call `predict` or `predict_proba` on the fitted model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Making predictions, classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 0, 0, 0, 0])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_pred = net.predict(X[:5])\n",
    "y_pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.54967159,  0.45032838],\n",
       "       [ 0.7842356 ,  0.2157644 ],\n",
       "       [ 0.67652136,  0.32347867],\n",
       "       [ 0.88522649,  0.1147735 ],\n",
       "       [ 0.68577141,  0.31422859]], dtype=float32)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_proba = net.predict_proba(X[:5])\n",
    "y_proba"
   ]
  },
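  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The predicted labels are consistent with the probabilities: in every row the first column is larger, so the label is 0. Assuming `predict` picks the class with the highest probability (an assumption about `NeuralNetClassifier`, not something this notebook shows), the labels can be recovered from the printed probabilities with `argmax`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# probabilities copied from the predict_proba output above\n",
    "y_proba_shown = np.array([\n",
    "    [0.54967159, 0.45032838],\n",
    "    [0.7842356, 0.2157644],\n",
    "    [0.67652136, 0.32347867],\n",
    "    [0.88522649, 0.1147735],\n",
    "    [0.68577141, 0.31422859],\n",
    "])\n",
    "\n",
    "# index of the most probable class in each row\n",
    "labels = y_proba_shown.argmax(axis=1)\n",
    "print(labels)  # [0 0 0 0 0], matching the predict output above"
   ]
  },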
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training a regressor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A toy regression task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from sklearn.datasets import make_regression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "X_regr, y_regr = make_regression(1000, 20, n_informative=10, random_state=0)\n",
    "X_regr = X_regr.astype(np.float32)\n",
    "y_regr = y_regr.astype(np.float32) / 100\n",
    "y_regr = y_regr.reshape(-1, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((1000, 20), (1000, 1), -6.4901485, 6.1545048)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_regr.shape, y_regr.shape, y_regr.min(), y_regr.max()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*Note*: Regression currently requires the target to be 2-dimensional, hence the need to reshape. This should be fixed with an upcoming version of pytorch."
   ]
  },
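  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick NumPy check of the reshape: `make_regression` returns a 1-dimensional target, and `reshape(-1, 1)` turns it into a single column, which matches the single output unit of the regression module. This cell only regenerates the same data for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import make_regression\n",
    "\n",
    "_, y_flat = make_regression(1000, 20, n_informative=10, random_state=0)\n",
    "print(y_flat.shape)  # (1000,)\n",
    "\n",
    "# -1 lets NumPy infer the number of rows; 1 means a single column\n",
    "y_col = (y_flat.astype(np.float32) / 100).reshape(-1, 1)\n",
    "print(y_col.shape)  # (1000, 1)"
   ]
  },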
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Definition of the `pytorch` regression `module`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Again, we define a vanilla neural network with two hidden layers. The main difference is that the output layer has only one unit and does not apply a softmax nonlinearity."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class RegressorModule(nn.Module):\n",
    "    def __init__(\n",
    "            self,\n",
    "            num_units=10,\n",
    "            nonlin=F.relu,\n",
    "    ):\n",
    "        super(RegressorModule, self).__init__()\n",
    "        self.num_units = num_units\n",
    "        self.nonlin = nonlin\n",
    "\n",
    "        self.dense0 = nn.Linear(20, num_units)  # 20 input features\n",
    "        self.dense1 = nn.Linear(num_units, 10)\n",
    "        self.output = nn.Linear(10, 1)  # a single unit for the scalar target\n",
    "\n",
    "    def forward(self, X, **kwargs):\n",
    "        X = self.nonlin(self.dense0(X))\n",
    "        X = F.relu(self.dense1(X))\n",
    "        # no softmax: the raw output is the predicted value\n",
    "        X = self.output(X)\n",
    "        return X"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Defining and training the neural net regressor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Training a regressor is almost the same as training a classifier. Mainly, we use `NeuralNetRegressor` instead of `NeuralNetClassifier` (this is the same terminology as in `sklearn`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from skorch.net import NeuralNetRegressor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "net_regr = NeuralNetRegressor(\n",
    "    RegressorModule,\n",
    "    max_epochs=20,\n",
    "    lr=0.1,\n",
    "    # use_cuda=True,  # uncomment this to train with CUDA\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  epoch    train_loss    valid_loss     dur\n",
      "-------  ------------  ------------  ------\n",
      "      1        \u001b[36m4.6059\u001b[0m        \u001b[32m3.5860\u001b[0m  0.0264\n",
      "      2        \u001b[36m3.5021\u001b[0m        \u001b[32m1.3814\u001b[0m  0.0421\n",
      "      3        \u001b[36m1.1019\u001b[0m        \u001b[32m0.5334\u001b[0m  0.0436\n",
      "      4        \u001b[36m0.7071\u001b[0m        \u001b[32m0.2994\u001b[0m  0.0414\n",
      "      5        \u001b[36m0.5654\u001b[0m        0.4141  0.0248\n",
      "      6        \u001b[36m0.3179\u001b[0m        \u001b[32m0.1574\u001b[0m  0.0242\n",
      "      7        \u001b[36m0.2476\u001b[0m        0.1906  0.0269\n",
      "      8        \u001b[36m0.1302\u001b[0m        \u001b[32m0.1049\u001b[0m  0.0250\n",
      "      9        0.1373        0.1124  0.0240\n",
      "     10        \u001b[36m0.0728\u001b[0m        \u001b[32m0.0737\u001b[0m  0.0265\n",
      "     11        0.0839        \u001b[32m0.0727\u001b[0m  0.0247\n",
      "     12        \u001b[36m0.0435\u001b[0m        \u001b[32m0.0513\u001b[0m  0.0267\n",
      "     13        0.0508        \u001b[32m0.0483\u001b[0m  0.0268\n",
      "     14        \u001b[36m0.0279\u001b[0m        \u001b[32m0.0371\u001b[0m  0.0261\n",
      "     15        0.0322        \u001b[32m0.0335\u001b[0m  0.0275\n",
      "     16        \u001b[36m0.0193\u001b[0m        \u001b[32m0.0282\u001b[0m  0.0260\n",
      "     17        0.0224        \u001b[32m0.0247\u001b[0m  0.0257\n",
      "     18        \u001b[36m0.0148\u001b[0m        \u001b[32m0.0221\u001b[0m  0.0271\n",
      "     19        0.0167        \u001b[32m0.0198\u001b[0m  0.0264\n",
      "     20        \u001b[36m0.0122\u001b[0m        \u001b[32m0.0182\u001b[0m  0.0268\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<skorch.net.NeuralNetRegressor at 0x7fb203dcd390>"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net_regr.fit(X_regr, y_regr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Making predictions, regression"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You may call `predict` or `predict_proba` on the fitted model. For regression, both methods return the same values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.52162153],\n",
       "       [-1.50998139],\n",
       "       [-0.90007448],\n",
       "       [-0.08845913],\n",
       "       [-0.52214217]], dtype=float32)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_pred = net_regr.predict(X_regr[:5])\n",
    "y_pred"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Saving and loading a model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "You can either save and load the whole model using `pickle`, or save and load only the learned module parameters by calling `save_params` and `load_params`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Saving the whole model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "file_name = '/tmp/mymodel.pkl'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/bbossan_dev/anaconda3/envs/skorch/lib/python3.6/site-packages/torch/serialization.py:158: UserWarning: Couldn't retrieve source code for container of type ClassifierModule. It won't be checked for correctness upon loading.\n",
      "  \"type \" + obj.__name__ + \". It won't be checked \"\n"
     ]
    }
   ],
   "source": [
    "with open(file_name, 'wb') as f:\n",
    "    pickle.dump(net, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "with open(file_name, 'rb') as f:\n",
    "    new_net = pickle.load(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Saving only the model parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This only saves and loads the learned `module` parameters, meaning that hyperparameters such as `lr` and `max_epochs` are not saved. Therefore, to load the parameters, we first have to re-initialize the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "net.save_params(file_name)  # a file handle also works"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# first initialize the model\n",
    "new_net = NeuralNetClassifier(\n",
    "    ClassifierModule,\n",
    "    max_epochs=20,\n",
    "    lr=0.1,\n",
    ").initialize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "new_net.load_params(file_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Usage with an `sklearn Pipeline`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It is possible to put the `NeuralNetClassifier` inside an `sklearn Pipeline`, as you would with any `sklearn` classifier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.preprocessing import StandardScaler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "pipe = Pipeline([\n",
    "    ('scale', StandardScaler()),\n",
    "    ('net', net),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Re-initializing module!\n",
      "  epoch    train_loss    valid_acc    valid_loss     dur\n",
      "-------  ------------  -----------  ------------  ------\n",
      "      1        \u001b[36m0.6891\u001b[0m       \u001b[32m0.5550\u001b[0m        \u001b[35m0.6853\u001b[0m  0.1064\n",
      "      2        \u001b[36m0.6826\u001b[0m       \u001b[32m0.5600\u001b[0m        \u001b[35m0.6825\u001b[0m  0.0576\n",
      "      3        0.6873       \u001b[32m0.5900\u001b[0m        \u001b[35m0.6801\u001b[0m  0.0514\n",
      "      4        \u001b[36m0.6797\u001b[0m       \u001b[32m0.6000\u001b[0m        \u001b[35m0.6776\u001b[0m  0.0448\n",
      "      5        \u001b[36m0.6772\u001b[0m       \u001b[32m0.6150\u001b[0m        \u001b[35m0.6751\u001b[0m  0.0434\n",
      "      6        \u001b[36m0.6748\u001b[0m       \u001b[32m0.6200\u001b[0m        \u001b[35m0.6723\u001b[0m  0.0429\n",
      "      7        \u001b[36m0.6682\u001b[0m       0.6200        \u001b[35m0.6691\u001b[0m  0.0429\n",
      "      8        \u001b[36m0.6645\u001b[0m       0.6200        \u001b[35m0.6654\u001b[0m  0.0473\n",
      "      9        \u001b[36m0.6623\u001b[0m       \u001b[32m0.6300\u001b[0m        \u001b[35m0.6613\u001b[0m  0.0524\n",
      "     10        \u001b[36m0.6464\u001b[0m       0.6200        \u001b[35m0.6555\u001b[0m  0.0612\n",
      "     11        0.6471       0.6300        \u001b[35m0.6491\u001b[0m  0.0490\n",
      "     12        \u001b[36m0.6449\u001b[0m       \u001b[32m0.6600\u001b[0m        \u001b[35m0.6424\u001b[0m  0.0493\n",
      "     13        \u001b[36m0.6285\u001b[0m       0.6500        \u001b[35m0.6341\u001b[0m  0.0520\n",
      "     14        \u001b[36m0.6265\u001b[0m       0.6500        \u001b[35m0.6261\u001b[0m  0.0487\n",
      "     15        \u001b[36m0.6252\u001b[0m       0.6600        \u001b[35m0.6193\u001b[0m  0.0341\n",
      "     16        \u001b[36m0.6148\u001b[0m       \u001b[32m0.6750\u001b[0m        \u001b[35m0.6102\u001b[0m  0.0282\n",
      "     17        \u001b[36m0.6039\u001b[0m       \u001b[32m0.6850\u001b[0m        \u001b[35m0.6017\u001b[0m  0.0475\n",
      "     18        \u001b[36m0.5979\u001b[0m       \u001b[32m0.6900\u001b[0m        \u001b[35m0.5949\u001b[0m  0.0435\n",
      "     19        \u001b[36m0.5794\u001b[0m       \u001b[32m0.7000\u001b[0m        \u001b[35m0.5849\u001b[0m  0.0440\n",
      "     20        \u001b[36m0.5596\u001b[0m       \u001b[32m0.7050\u001b[0m        \u001b[35m0.5758\u001b[0m  0.0484\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Pipeline(memory=None,\n",
       "     steps=[('scale', StandardScaler(copy=True, with_mean=True, with_std=True)), ('net', <skorch.net.NeuralNetClassifier object at 0x7fb203e05668>)])"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pipe.fit(X, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.39650354,  0.60349649],\n",
       "       [ 0.73950195,  0.26049808],\n",
       "       [ 0.72104084,  0.27895918],\n",
       "       [ 0.71111423,  0.2888858 ],\n",
       "       [ 0.66332674,  0.33667326]], dtype=float32)"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_proba = pipe.predict_proba(X[:5])\n",
    "y_proba"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To save the whole pipeline, including the pytorch module, use `pickle`."
   ]
  },
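  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of such a round trip, using an in-memory buffer instead of a file. To keep the cell independent of the fitted net, `LogisticRegression` stands in for the `NeuralNetClassifier` step (a substitution for illustration only); pickling works the same way for any fitted pipeline whose steps can be pickled."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import io\n",
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "from sklearn.datasets import make_classification\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "X_demo, y_demo = make_classification(1000, 20, n_informative=10, random_state=0)\n",
    "X_demo = X_demo.astype(np.float32)\n",
    "\n",
    "# LogisticRegression stands in for the skorch net so this cell runs on its own\n",
    "demo_pipe = Pipeline([\n",
    "    ('scale', StandardScaler()),\n",
    "    ('clf', LogisticRegression()),\n",
    "])\n",
    "demo_pipe.fit(X_demo, y_demo)\n",
    "\n",
    "# round trip through an in-memory buffer instead of a file on disk\n",
    "buffer = io.BytesIO()\n",
    "pickle.dump(demo_pipe, buffer)\n",
    "buffer.seek(0)\n",
    "restored_pipe = pickle.load(buffer)\n",
    "\n",
    "# the fitted scaler statistics and model weights survive the round trip\n",
    "same = np.allclose(demo_pipe.predict_proba(X_demo[:5]),\n",
    "                   restored_pipe.predict_proba(X_demo[:5]))\n",
    "print(same)  # True"
   ]
  },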
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Callbacks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Adding a new callback to the model is straightforward. Below we show how to add a callback that computes the area under the ROC curve (ROC AUC)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from skorch.callbacks import EpochScoring"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There is a scoring callback in skorch, `EpochScoring`, which we use for this. We have to specify which score to calculate. We have 3 choices:\n",
    "\n",
    "* Passing a string: This should be a valid `sklearn` metric. For a list of all existing scores, look [here](http://scikit-learn.org/stable/modules/classes.html#sklearn-metrics-metrics).\n",
    "* Passing `None`: If you implement your own `.score` method on your neural net, passing `scoring=None` will tell `skorch` to use that.\n",
    "* Passing a function or callable: If we want to define our own scoring function, we pass a function with the signature `func(model, X, y) -> score`, which is then used.\n",
    "\n",
    "Note that this works exactly the same as scoring in `sklearn` does."
   ]
  },
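  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate the string and callable options outside of `skorch`, the sketch below uses plain `sklearn` with `LogisticRegression` as a stand-in for the net. `auc_scoring` is a hypothetical function with the `func(model, X, y) -> score` signature described above; it computes the same ROC AUC as the built-in `'roc_auc'` scorer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import make_classification\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.metrics import get_scorer, roc_auc_score\n",
    "\n",
    "X_demo, y_demo = make_classification(1000, 20, n_informative=10, random_state=0)\n",
    "clf = LogisticRegression().fit(X_demo, y_demo)\n",
    "\n",
    "\n",
    "# hypothetical custom scorer following the func(model, X, y) -> score signature\n",
    "def auc_scoring(model, X, y):\n",
    "    # probability of the positive class (column 1)\n",
    "    return roc_auc_score(y, model.predict_proba(X)[:, 1])\n",
    "\n",
    "\n",
    "builtin_auc = get_scorer('roc_auc')(clf, X_demo, y_demo)  # scoring by name\n",
    "custom_auc = auc_scoring(clf, X_demo, y_demo)  # scoring by callable\n",
    "print(builtin_auc, custom_auc)"
   ]
  },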
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For our case here, since `sklearn` already implements AUC, we just pass the correct string `'roc_auc'`. We should also tell the callback that higher scores are better (to get the correct colors printed below -- by default, lower scores are assumed to be better). Furthermore, we may specify a `name` argument for `EpochScoring`, and whether to use training data (by setting `on_train=True`) or validation data (which is the default)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "auc = EpochScoring(scoring='roc_auc', lower_is_better=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we pass the scoring callback to the `callbacks` parameter as a list and then call `fit`. Notice that we get the printed scores and color highlighting for free."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "net = NeuralNetClassifier(\n",
    "    ClassifierModule,\n",
    "    max_epochs=20,\n",
    "    lr=0.1,\n",
    "    callbacks=[auc],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  epoch    roc_auc    train_loss    valid_acc    valid_loss     dur\n",
      "-------  ---------  ------------  -----------  ------------  ------\n",
      "      1     \u001b[36m0.5911\u001b[0m        \u001b[32m0.7204\u001b[0m       \u001b[35m0.5000\u001b[0m        \u001b[31m0.6948\u001b[0m  0.0590\n",
      "      2     \u001b[36m0.6524\u001b[0m        \u001b[32m0.6925\u001b[0m       \u001b[35m0.5300\u001b[0m        \u001b[31m0.6881\u001b[0m  0.0502\n",
      "      3     \u001b[36m0.6700\u001b[0m        \u001b[32m0.6867\u001b[0m       \u001b[35m0.6000\u001b[0m        \u001b[31m0.6857\u001b[0m  0.0321\n",
      "      4     \u001b[36m0.6854\u001b[0m        \u001b[32m0.6820\u001b[0m       \u001b[35m0.6400\u001b[0m        \u001b[31m0.6832\u001b[0m  0.0364\n",
      "      5     0.6829        \u001b[32m0.6801\u001b[0m       0.6050        \u001b[31m0.6812\u001b[0m  0.0377\n",
      "      6     0.6757        \u001b[32m0.6742\u001b[0m       0.6100        \u001b[31m0.6796\u001b[0m  0.0541\n",
      "      7     0.6808        0.6762       0.6100        \u001b[31m0.6776\u001b[0m  0.0528\n",
      "      8     0.6759        \u001b[32m0.6576\u001b[0m       0.6350        \u001b[31m0.6747\u001b[0m  0.0317\n",
      "      9     0.6813        0.6661       0.6350        \u001b[31m0.6707\u001b[0m  0.0525\n",
      "     10     \u001b[36m0.6903\u001b[0m        \u001b[32m0.6548\u001b[0m       \u001b[35m0.6450\u001b[0m        \u001b[31m0.6655\u001b[0m  0.0467\n",
      "     11     \u001b[36m0.6929\u001b[0m        \u001b[32m0.6500\u001b[0m       0.6400        \u001b[31m0.6611\u001b[0m  0.0495\n",
      "     12     0.6920        \u001b[32m0.6445\u001b[0m       \u001b[35m0.6500\u001b[0m        \u001b[31m0.6571\u001b[0m  0.0314\n",
      "     13     \u001b[36m0.7095\u001b[0m        \u001b[32m0.6372\u001b[0m       \u001b[35m0.6650\u001b[0m        \u001b[31m0.6509\u001b[0m  0.0390\n",
      "     14     \u001b[36m0.7155\u001b[0m        \u001b[32m0.6288\u001b[0m       \u001b[35m0.6700\u001b[0m        \u001b[31m0.6446\u001b[0m  0.0532\n",
      "     15     \u001b[36m0.7265\u001b[0m        \u001b[32m0.6268\u001b[0m       0.6700        \u001b[31m0.6390\u001b[0m  0.0494\n",
      "     16     \u001b[36m0.7398\u001b[0m        \u001b[32m0.6150\u001b[0m       \u001b[35m0.6900\u001b[0m        \u001b[31m0.6308\u001b[0m  0.0609\n",
      "     17     \u001b[36m0.7487\u001b[0m        0.6221       \u001b[35m0.7000\u001b[0m        \u001b[31m0.6246\u001b[0m  0.0540\n",
      "     18     0.7473        0.6168       \u001b[35m0.7250\u001b[0m        \u001b[31m0.6187\u001b[0m  0.0529\n",
      "     19     \u001b[36m0.7588\u001b[0m        \u001b[32m0.5945\u001b[0m       \u001b[35m0.7400\u001b[0m        \u001b[31m0.6100\u001b[0m  0.0522\n",
      "     20     \u001b[36m0.7664\u001b[0m        0.6000       \u001b[35m0.7650\u001b[0m        \u001b[31m0.6026\u001b[0m  0.0524\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<skorch.net.NeuralNetClassifier at 0x7fb203db76a0>"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net.fit(X, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For information on how to write custom callbacks, have a look at the [Advanced_Usage](https://nbviewer.jupyter.org/github/dnouri/skorch/blob/master/notebooks/Advanced_Usage.ipynb) notebook."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## Usage with sklearn `GridSearchCV`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Special prefixes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `NeuralNet` class lets you directly access parameters of the `pytorch module` by using the `module__` prefix. For example, if you defined the `module` to have a `num_units` parameter, you can set it via the `module__num_units` argument. This is exactly the same logic that lets you access estimator parameters in `sklearn Pipeline`s and `FeatureUnion`s."
   ]
  },
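  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `sklearn` side of this convention can be seen directly on a `Pipeline`, where `<step>__<parameter>` addresses a parameter of one step. The steps below (`StandardScaler`, `LogisticRegression`) are stand-ins chosen for illustration; `module__num_units` follows the same naming idea for the `module`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "demo_pipe = Pipeline([\n",
    "    ('scale', StandardScaler()),\n",
    "    ('clf', LogisticRegression()),\n",
    "])\n",
    "\n",
    "# '<step>__<param>' sets the parameter on the named step\n",
    "demo_pipe.set_params(clf__C=0.5, scale__with_mean=False)\n",
    "\n",
    "print(demo_pipe.named_steps['clf'].C)  # 0.5\n",
    "print(demo_pipe.get_params()['scale__with_mean'])  # False"
   ]
  },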
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is useful in several ways. First, it lets you set those parameters directly when defining the net. Second, it lets you tune them with `sklearn`'s `GridSearchCV`, as shown below."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Besides `module__`, a few other special prefixes give access to other components, e.g. `optimizer__` for the parameters of the optimizer (again, see below). All of these special prefixes are stored in the `prefixes_` attribute:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "module, iterator_train, iterator_valid, optimizer, criterion, callbacks, dataset\n"
     ]
    }
   ],
   "source": [
    "print(', '.join(net.prefixes_))"
   ]
  },
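  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the routing idea behind these prefixes concrete, the cell below is a minimal pure-Python sketch. The helper `split_prefixed_params` is hypothetical and only illustrates the `prefix__name` convention; it is **not** how `skorch` implements it internally. Splitting at the *first* `__` is what makes nesting possible: a key such as `net__module__num_units` is routed to `net`, which can then route `module__num_units` further."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helper illustrating the `prefix__name` convention;\n",
    "# this is NOT skorch's actual implementation.\n",
    "def split_prefixed_params(params, prefixes):\n",
    "    # Group `prefix__name` keys by known prefix; keep all other keys as-is.\n",
    "    routed = {prefix: {} for prefix in prefixes}\n",
    "    plain = {}\n",
    "    for key, value in params.items():\n",
    "        # Split at the first '__' only, so nested names keep their tail.\n",
    "        prefix, sep, name = key.partition('__')\n",
    "        if sep and prefix in routed:\n",
    "            routed[prefix][name] = value\n",
    "        else:\n",
    "            plain[key] = value\n",
    "    return routed, plain\n",
    "\n",
    "\n",
    "routed, plain = split_prefixed_params(\n",
    "    {'lr': 0.1, 'module__num_units': 20, 'optimizer__momentum': 0.9},\n",
    "    prefixes=['module', 'optimizer'],\n",
    ")\n",
    "print(routed)  # {'module': {'num_units': 20}, 'optimizer': {'momentum': 0.9}}\n",
    "print(plain)   # {'lr': 0.1}"
   ]
  },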
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Performing a grid search"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below we show how to perform a grid search over the learning rate (`lr`), the module's number of hidden units (`module__num_units`), the module's dropout rate (`module__dropout`), and whether the SGD optimizer should use Nesterov momentum (`optimizer__nesterov`). With two values per parameter, this gives 2 × 2 × 2 × 2 = 16 candidates; with 3-fold cross-validation (`cv=3`), that makes 48 fits in total."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from sklearn.model_selection import GridSearchCV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "net = NeuralNetClassifier(\n",
    "    ClassifierModule,\n",
    "    max_epochs=20,\n",
    "    lr=0.1,\n",
    "    verbose=0,\n",
    "    optimizer__momentum=0.9,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "params = {\n",
    "    'lr': [0.05, 0.1],\n",
    "    'module__num_units': [10, 20],\n",
    "    'module__dropout': [0, 0.5],\n",
    "    'optimizer__nesterov': [False, True],\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "gs = GridSearchCV(net, params, refit=False, cv=3, scoring='accuracy', verbose=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 3 folds for each of 16 candidates, totalling 48 fits\n",
      "[CV] lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=False \n",
      "[CV]  lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=False, total=   0.9s\n",
      "[CV] lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=False \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.9s remaining:    0.0s\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV]  lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=False, total=   1.1s\n",
      "[CV] lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=False \n",
      "[CV]  lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=False, total=   2.4s\n",
      "[CV] lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=True \n",
      "[CV]  lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=True, total=   2.4s\n",
      "[CV] lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=True \n",
      "[CV]  lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=True, total=   1.0s\n",
      "[CV] lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=True \n",
      "[CV]  lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=True, total=   0.8s\n",
      "[CV] lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=False \n",
      "[CV]  lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=False, total=   0.8s\n",
      "[CV] lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=False \n",
      "[CV]  lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=False, total=   0.8s\n",
      "[CV] lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=False \n",
      "[CV]  lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=False, total=   0.8s\n",
      "[CV] lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=True \n",
      "[CV]  lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=True, total=   1.0s\n",
      "[CV] lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=True \n",
      "[CV]  lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=True, total=   0.9s\n",
      "[CV] lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=True \n",
      "[CV]  lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=True, total=   0.8s\n",
      "[CV] lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False \n",
      "[CV]  lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False, total=   0.8s\n",
      "[CV] lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False \n",
      "[CV]  lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False, total=   0.8s\n",
      "[CV] lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False \n",
      "[CV]  lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False, total=   0.9s\n",
      "[CV] lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True \n",
      "[CV]  lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True, total=   1.0s\n",
      "[CV] lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True \n",
      "[CV]  lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True, total=   1.0s\n",
      "[CV] lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True \n",
      "[CV]  lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True, total=   0.9s\n",
      "[CV] lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False \n",
      "[CV]  lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False, total=   1.0s\n",
      "[CV] lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False \n",
      "[CV]  lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False, total=   0.8s\n",
      "[CV] lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False \n",
      "[CV]  lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False, total=   0.8s\n",
      "[CV] lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True \n",
      "[CV]  lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True, total=   0.8s\n",
      "[CV] lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True \n",
      "[CV]  lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True, total=   0.8s\n",
      "[CV] lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True \n",
      "[CV]  lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True, total=   0.9s\n",
      "[CV] lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=False \n",
      "[CV]  lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=False, total=   0.9s\n",
      "[CV] lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=False \n",
      "[CV]  lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=False, total=   0.8s\n",
      "[CV] lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=False \n",
      "[CV]  lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=False, total=   0.8s\n",
      "[CV] lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=True \n",
      "[CV]  lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=True, total=   0.8s\n",
      "[CV] lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=True \n",
      "[CV]  lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=True, total=   0.8s\n",
      "[CV] lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=True \n",
      "[CV]  lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=True, total=   0.8s\n",
      "[CV] lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=False \n",
      "[CV]  lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=False, total=   0.8s\n",
      "[CV] lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=False \n",
      "[CV]  lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=False, total=   0.7s\n",
      "[CV] lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=False \n",
      "[CV]  lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=False, total=   0.8s\n",
      "[CV] lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=True \n",
      "[CV]  lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=True, total=   0.8s\n",
      "[CV] lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=True \n",
      "[CV]  lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=True, total=   0.8s\n",
      "[CV] lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=True \n",
      "[CV]  lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=True, total=   0.8s\n",
      "[CV] lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False \n",
      "[CV]  lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False, total=   0.8s\n",
      "[CV] lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False \n",
      "[CV]  lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False, total=   0.7s\n",
      "[CV] lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False \n",
      "[CV]  lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False, total=   0.8s\n",
      "[CV] lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True \n",
      "[CV]  lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True, total=   0.9s\n",
      "[CV] lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True \n",
      "[CV]  lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True, total=   0.8s\n",
      "[CV] lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True \n",
      "[CV]  lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True, total=   0.8s\n",
      "[CV] lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False \n",
      "[CV]  lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False, total=   0.8s\n",
      "[CV] lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False \n",
      "[CV]  lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False, total=   0.8s\n",
      "[CV] lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False \n",
      "[CV]  lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False, total=   0.8s\n",
      "[CV] lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True \n",
      "[CV]  lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True, total=   0.8s\n",
      "[CV] lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True \n",
      "[CV]  lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True, total=   0.9s\n",
      "[CV] lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV]  lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True, total=   0.9s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=1)]: Done  48 out of  48 | elapsed:   44.6s finished\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "GridSearchCV(cv=3, error_score='raise',\n",
       "       estimator=<skorch.net.NeuralNetClassifier object at 0x7fb203dc3320>,\n",
       "       fit_params=None, iid=True, n_jobs=1,\n",
       "       param_grid={'lr': [0.05, 0.1], 'module__num_units': [10, 20], 'module__dropout': [0, 0.5], 'optimizer__nesterov': [False, True]},\n",
       "       pre_dispatch='2*n_jobs', refit=False, return_train_score=True,\n",
       "       scoring='accuracy', verbose=2)"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gs.fit(X, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.856 {'lr': 0.1, 'module__dropout': 0, 'module__num_units': 20, 'optimizer__nesterov': False}\n"
     ]
    }
   ],
   "source": [
    "print(gs.best_score_, gs.best_params_)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "We can also nest the `NeuralNetClassifier` inside an `sklearn` `Pipeline`. In that case, we just prefix each parameter with the name of the pipeline step that holds the net (e.g. `net__module__num_units` if that step is named `net`)."
   ]
  },
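  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell below is a minimal, self-contained sketch of such a nested setup, assuming recent versions of `skorch`, `torch` and `sklearn`. The data (`X_pipe`, `y_pipe`), the `PipelineModule` class and the step names `scale` and `net` are hypothetical choices made only for this example; they deliberately do not reuse the names defined earlier in this notebook. Because the net lives in the step named `net`, its parameters are addressed as `net__lr` and `net__module__num_units`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "from sklearn.datasets import make_classification\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from skorch import NeuralNetClassifier\n",
    "\n",
    "# Hypothetical toy data; skorch expects float32 features and int64 targets.\n",
    "X_pipe, y_pipe = make_classification(500, 20, n_informative=10, random_state=0)\n",
    "X_pipe = X_pipe.astype(np.float32)\n",
    "y_pipe = y_pipe.astype(np.int64)\n",
    "\n",
    "\n",
    "class PipelineModule(nn.Module):\n",
    "    # Hypothetical two-layer classifier used only for this sketch.\n",
    "    def __init__(self, num_units=10, dropout=0.5):\n",
    "        super().__init__()\n",
    "        self.dense0 = nn.Linear(20, num_units)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.output = nn.Linear(num_units, 2)\n",
    "\n",
    "    def forward(self, X):\n",
    "        X = torch.relu(self.dense0(X))\n",
    "        X = self.dropout(X)\n",
    "        # With its default NLLLoss criterion, NeuralNetClassifier expects probabilities.\n",
    "        return torch.softmax(self.output(X), dim=-1)\n",
    "\n",
    "\n",
    "pipe = Pipeline([\n",
    "    ('scale', StandardScaler()),  # float32 input stays float32\n",
    "    ('net', NeuralNetClassifier(\n",
    "        PipelineModule,\n",
    "        max_epochs=10,\n",
    "        lr=0.1,\n",
    "        verbose=0,\n",
    "        optimizer__momentum=0.9,\n",
    "    )),\n",
    "])\n",
    "\n",
    "# Each net parameter now carries the extra `net__` step prefix.\n",
    "pipe_params = {\n",
    "    'net__lr': [0.05, 0.1],\n",
    "    'net__module__num_units': [10, 20],\n",
    "}\n",
    "pipe_gs = GridSearchCV(pipe, pipe_params, refit=False, cv=3, scoring='accuracy')\n",
    "pipe_gs.fit(X_pipe, y_pipe)\n",
    "print(pipe_gs.best_score_, pipe_gs.best_params_)"
   ]
  },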
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
